{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 0
   },
   "source": [
    "# Index and Shape Expressions\n",
    "\n",
    "You already know that a shape can be a tuple of symbols such as `(n, m)` and the elements can be accessed via indexing, e.g. `a[i, j]`. In practice, both shapes and indices may be computed through complex expressions. We will go through several examples in this section.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "origin_pos": 1,
    "tab": [
     "tvm"
    ]
   },
   "outputs": [],
   "source": [
    "import d2ltvm\n",
    "import numpy as np\n",
    "import tvm\n",
    "from tvm import te"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 2
   },
   "source": [
    "## Matrix Transpose\n",
    "\n",
    "Our first example is matrix transpose `a.T`, in which we access `a`'s elements by columns.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "16"
    },
    "origin_pos": 3,
    "tab": [
     "tvm"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "produce b {\n",
       "  for (i, 0, m) {\n",
       "    for (j, 0, n) {\n",
       "      b[((i*stride) + (j*stride))] = a[((j*stride) + (i*stride))]\n",
       "    }\n",
       "  }\n",
       "}"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "n = te.var('n')\n",
    "m = te.var('m')\n",
    "A = te.placeholder((n, m), name='a')\n",
    "B = te.compute((m, n), lambda i, j: A[j, i], 'b')\n",
    "s = te.create_schedule(B.op)\n",
    "tvm.lower(s, [A, B], simple_mode=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 4
   },
   "source": [
    "Note that the 2-D index, e.g. `b[i,j]` are collapsed to the 1-D index `b[((i*n) + j)]` to follow the C convention.\n",
    "\n",
    "Now verify the results.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "21"
    },
    "origin_pos": 5,
    "tab": [
     "tvm"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[ 0.  1.  2.  3.]\n",
      " [ 4.  5.  6.  7.]\n",
      " [ 8.  9. 10. 11.]]\n",
      "[[ 0.  4.  8.]\n",
      " [ 1.  5.  9.]\n",
      " [ 2.  6. 10.]\n",
      " [ 3.  7. 11.]]\n"
     ]
    }
   ],
   "source": [
    "a = np.arange(12, dtype='float32').reshape((3, 4))\n",
    "b = np.empty((4, 3), dtype='float32')\n",
    "a, b = tvm.nd.array(a), tvm.nd.array(b)\n",
    "\n",
    "mod = tvm.build(s, [A, B])\n",
    "mod(a, b)\n",
    "print(a)\n",
    "print(b)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 6
   },
   "source": [
    "## Reshaping\n",
    "\n",
    "Next let's use expressions for indexing. The following code block reshapes a 2-D array `a` (`n` by `m` as defined above) to 1-D (just like `a.reshape(-1)` in NumPy). Note how we convert the 1-D index `i` to the 2-D index `[i//m, i%m]`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "28"
    },
    "origin_pos": 7,
    "tab": [
     "tvm"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "produce b {\n",
       "  for (i, 0, (m*n)) {\n",
       "    b[i] = a[((floordiv(i, m)*stride) + (floormod(i, m)*stride))]\n",
       "  }\n",
       "}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "B = te.compute((m*n, ), lambda i: A[i//m, i%m], name='b')\n",
    "s = te.create_schedule(B.op)\n",
    "tvm.lower(s, [A, B], simple_mode=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 8
   },
   "source": [
    "Since an $n$-D array is essentially listed in the memory as a 1-D array, the generated code does not rearrange the data sequence, but it simplifies the index expression from 2-D (`(i//m)*m + i%m`) to 1-D (`i`) to improve the efficiency.\n",
    "\n",
    "We can implement a general 2-D reshape function as well.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "31"
    },
    "origin_pos": 9,
    "tab": [
     "tvm"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "produce b {\n",
       "  for (i, 0, p) {\n",
       "    for (j, 0, q) {\n",
       "      b[((i*stride) + (j*stride))] = a[((floordiv(((i*q) + j), m)*stride) + (floormod(((i*q) + j), m)*stride))]\n",
       "    }\n",
       "  }\n",
       "}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "p, q = te.var('p'), te.var('q')\n",
    "B = te.compute((p, q), lambda i, j: A[(i*q+j)//m, (i*q+j)%m], name='b')\n",
    "s = te.create_schedule(B.op)\n",
    "tvm.lower(s, [A, B], simple_mode=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 10
   },
   "source": [
    "When testing the results, we should be aware that we put no constraint on the output shape, which can have an arbitrary shape `(p, q)`, and therefore TVM will not be able to check if $qp = nm$ for us. For example, in the following example we created a `b` with size (20) larger than `a` (12), then only the first 12 elements in `b` are from `a`, others are uninitialized values.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "origin_pos": 11,
    "tab": [
     "tvm"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0.000000e+00 1.000000e+00 2.000000e+00 3.000000e+00]\n",
      " [4.000000e+00 5.000000e+00 6.000000e+00 7.000000e+00]\n",
      " [8.000000e+00 9.000000e+00 1.000000e+01 1.100000e+01]\n",
      " [8.722636e-23 3.066461e-41 9.108440e-44 0.000000e+00]\n",
      " [8.743578e-23 3.066461e-41 8.741932e-23 3.066461e-41]]\n"
     ]
    }
   ],
   "source": [
    "mod = tvm.build(s, [A, B])\n",
    "a = np.arange(12, dtype='float32').reshape((3,4))\n",
    "b = np.zeros((5, 4), dtype='float32')\n",
    "a, b = tvm.nd.array(a), tvm.nd.array(b)\n",
    "\n",
    "mod(a, b)\n",
    "print(b)"
   ]
  },
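  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because TVM does not perform this check, one option is to validate the shapes ourselves before calling the module. The helper below is a hypothetical sketch (the name `check_reshape` is an assumption, not part of TVM) that works on plain shape tuples:\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def check_reshape(in_shape, out_shape):\n",
    "    # Hypothetical helper: raise if the element counts differ.\n",
    "    n_in, n_out = math.prod(in_shape), math.prod(out_shape)\n",
    "    if n_in != n_out:\n",
    "        raise ValueError(\n",
    "            f'cannot reshape {n_in} elements into shape {out_shape}')\n",
    "\n",
    "check_reshape((3, 4), (2, 6))      # same element count, passes silently\n",
    "try:\n",
    "    check_reshape((3, 4), (5, 4))  # the mismatched case from above\n",
    "except ValueError as err:\n",
    "    message = str(err)\n",
    "```\n",
    "\n",
    "Assuming both arrays expose a `shape` tuple, it could be called as `check_reshape(a.shape, b.shape)` before `mod(a, b)`.\n"
   ]
  },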
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 12
   },
   "source": [
    "## Slicing\n",
    "\n",
    "Now let's consider a special slicing operator `a[bi::si, bj::sj] ` where `bi`, `bj`, `si` and `sj` can be specified later. Now the output shape needs to be computed based on the arguments. In addition, we need to pass the variables `bi`, `bj`, `si` and `sj` as arguments when compiling the module.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "origin_pos": 13,
    "tab": [
     "tvm"
    ]
   },
   "outputs": [],
   "source": [
    "bi, bj, si, sj = [te.var(name) for name in ['bi', 'bj', 'si', 'sj']]\n",
    "B = te.compute(((n-bi)//si, (m-bj)//sj), lambda i, j: A[i*si+bi, j*sj+bj], name='b')\n",
    "s = te.create_schedule(B.op)\n",
    "mod = tvm.build(s, [A, B, bi, si, bj, sj])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 14
   },
   "source": [
    "Now test two cases to verify the correctness.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "origin_pos": 15,
    "tab": [
     "tvm"
    ]
   },
   "outputs": [],
   "source": [
    "b = tvm.nd.array(np.empty((1, 3), dtype='float32'))\n",
    "mod(a, b, 1, 2, 1, 1)\n",
    "np.testing.assert_equal(b.asnumpy(), a.asnumpy()[1::2, 1::1])\n",
    "\n",
    "b = tvm.nd.array(np.empty((1, 2), dtype='float32'))\n",
    "mod(a, b, 2, 1, 0, 2)\n",
    "np.testing.assert_equal(b.asnumpy(), a.asnumpy()[2::1, 0::2])"
   ]
  },
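  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The number of elements selected by a strided slice can be cross-checked against NumPy. The sketch below is illustrative only; the helper name `slice_len` and the axis length are assumptions. Along an axis of length `n`, `x[begin::step]` selects `ceil((n - begin) / step)` elements, which agrees with the floor division `(n - begin)//step` only when `step` divides `n - begin`, as in the two test cases above.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def slice_len(n, begin, step):\n",
    "    # Number of elements selected by x[begin::step] on an axis of length n,\n",
    "    # assuming 0 <= begin <= n and step > 0 (ceiling division).\n",
    "    return (n - begin + step - 1) // step\n",
    "\n",
    "n = 7  # illustrative axis length\n",
    "x = np.arange(n)\n",
    "# Collect every (begin, step) pair where the formula disagrees with NumPy\n",
    "mismatches = [(b, s) for b in range(n + 1) for s in range(1, n + 1)\n",
    "              if slice_len(n, b, s) != len(x[b::s])]\n",
    "```\n",
    "\n",
    "If the formula is right, `mismatches` should be empty.\n"
   ]
  },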
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 16
   },
   "source": [
    "## Summary\n",
    "\n",
    "- Both shape dimensions and indices can be expressions with variables.\n",
    "- If a variable doesn't only appear in the shape tuple, we need to pass it as an argument when compiling.\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}